InfoGAN

Please note that this is an optional notebook meant to introduce more advanced concepts. If you’re up for a challenge, take a look and don’t worry if you can’t follow everything. There is no code to implement—only some cool code for you to learn and run!

Goals

In this notebook, you're going to learn about InfoGAN, a way to generate disentangled outputs. The notebook follows the paper "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets" by Chen et al. There are many approaches to disentanglement, and InfoGAN is one of the better known and more widely used.

InfoGAN can be understood like this: you want to separate your model's input into two parts. The first is $z$, which corresponds to truly random noise. The second is $c$, which corresponds to the "latent code." The latent code $c$ can be thought of as a "hidden" condition in a conditional generator, and you'd like it to have an interpretable meaning.
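Here is a minimal sketch of this split. It uses sizes that match the settings later in this notebook (z_dim = 64 and c_dim = 2) and the notebook's standard normal prior on $c$. The generator's input is just the noise and the latent code concatenated:

```python
import torch

# Sizes chosen to match this notebook's later settings (an illustration, not a requirement)
z_dim, c_dim, n_samples = 64, 2, 16

z = torch.randn(n_samples, z_dim)  # incompressible random noise
c = torch.randn(n_samples, c_dim)  # latent code, here drawn from a standard normal prior
gen_input = torch.cat([z, c], dim=1)  # the generator sees both, side by side
print(gen_input.shape)  # torch.Size([16, 66])
```

Nothing in the generator marks which inputs are code and which are noise. Only the mutual information term introduced below makes $c$ special.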

Now, you'll likely wonder: how do the authors get $c$, which is just a random set of numbers, to be more interpretable than any input dimension of a typical GAN? The answer is "mutual information." Essentially, you would like each dimension of the latent code to be as predictable as possible from the generated images. Read on for a more thorough theoretical and practical treatment.

Formally: Variational Lower Bound

The information entropy $H(X) = -\sum_{i=1}^{n} P(x_i)\log P(x_i)$ can be understood as the amount of "information" in the distribution of $X$. For example, the information entropy of $n$ fair coins is $n$ bits (with the logarithm taken in base 2). You've also seen a similar equation before: the cross-entropy loss. The mutual information $I(X;Y) = H(X) - H(X\vert Y)$ is what the authors of InfoGAN describe, intuitively, as the "reduction of uncertainty in $X$ when $Y$ is observed."
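To make these definitions concrete, here is a small sketch in plain Python. It checks the coin example and computes $I(X;Y)$ for a made-up joint distribution of two binary variables. The helper `entropy_bits` is hypothetical and uses base-2 logarithms, so results are in bits:

```python
import math

def entropy_bits(probs):
    """Shannon entropy H(X) = -sum_i P(x_i) log2 P(x_i), in bits (0 log 0 treated as 0)."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

# n fair coins: 2**n equally likely outcomes, so H = n bits
n = 3
h_coins = entropy_bits([1 / 2**n] * 2**n)
print(h_coins)  # 3.0

# A small, made-up joint distribution P(x, y) over two binary variables
joint = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}
p_x = {x: sum(p for (xi, _), p in joint.items() if xi == x) for x in (0, 1)}
p_y = {y: sum(p for (_, yi), p in joint.items() if yi == y) for y in (0, 1)}

# H(X | Y) = sum_y P(y) * H(X | Y = y)
h_x_given_y = sum(
    p_y[y] * entropy_bits([joint[(x, y)] / p_y[y] for x in (0, 1)]) for y in (0, 1)
)

# I(X; Y) = H(X) - H(X | Y): how much observing Y reduces uncertainty about X
mutual_info = entropy_bits(p_x.values()) - h_x_given_y
print(round(mutual_info, 3))  # 0.278
```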

In InfoGAN, you'd like to maximize $I(c; G(z, c))$, the mutual information between the latent code $c$ and the generated images $G(z, c)$. The true posterior $P(c | G(z, c))$ is difficult to compute, so you add a second output to the discriminator that learns an approximation $Q(c | G(z, c))$ of it.

Let $\Delta = D_{KL}(P(\cdot|x) \Vert Q(\cdot|x))$ be the Kullback-Leibler divergence between the true and approximate posterior distributions. Then, based on Equation 4 in the paper, the mutual information has the following lower bound:

$$\begin{split} I(c; G(z, c)) & = H(c) - H(c|G(z, c)) \\ & = {\mathbb{E}}_{x \sim G(z, c)} [ {\mathbb{E}}_{c' \sim P(c|x)} \log P(c' | x) ] + H(c) \textit{ (by definition of H)}\\ & = {\mathbb{E}}_{x \sim G(z, c)} [\Delta + {\mathbb{E}}_{c' \sim P(c|x)} \log Q(c' | x) ] + H(c) \textit{ (approximation error)}\\ & \geq {\mathbb{E}}_{x \sim G(z, c)} [{\mathbb{E}}_{c' \sim P(c|x)} \log Q(c' | x) ] + H(c) \textit{ (KL divergence is non-negative)} \end{split}$$
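The key step in the bound is the identity $\mathbb{E}_{c' \sim P}[\log P(c'|x)] = \Delta + \mathbb{E}_{c' \sim P}[\log Q(c'|x)]$. Here is a quick numerical check for a single image $x$. The three-valued posterior `p` and approximation `q` are made up for illustration:

```python
import math

# Hypothetical true posterior P(c | x) and approximation Q(c | x) for one fixed image x
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

e_log_p = sum(pi * math.log(pi) for pi in p)                 # E_{c'~P}[log P(c'|x)] = -H(c | x)
e_log_q = sum(pi * math.log(qi) for pi, qi in zip(p, q))     # E_{c'~P}[log Q(c'|x)]
kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))     # Delta = KL(P || Q) >= 0

# The identity holds exactly, so dropping Delta can only lower the value
print(abs(e_log_p - (kl + e_log_q)) < 1e-12)  # True
print(e_log_q <= e_log_p)  # True
```

The gap between the two sides is exactly $\Delta$. The bound therefore tightens as $Q$ approaches the true posterior.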

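For a given latent code distribution, $H(c)$ is fixed. The following term therefore makes a good objective to maximize (equivalently, its negation makes a good loss to minimize):

$${\mathbb{E}}_{x \sim G(z, c)} [{\mathbb{E}}_{c' \sim P(c|x)} \log Q(c' | x) ]$$

This is the negative of the cross-entropy between the true posterior $P(c|x)$ and the approximation $Q(c|x)$, averaged over the generator's images. Maximizing it trains $Q$ to recover the code from the image. In practice, you never need to sample from the intractable posterior. The paper shows that this expectation equals ${\mathbb{E}}_{c \sim P(c), x \sim G(z, c)} [\log Q(c | x)]$, so you can simply score $Q$ on the code that actually generated each image.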
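Here is a minimal sketch for a categorical latent code. The logits are random stand-ins for a hypothetical $Q$ head, and no trained model is involved. It shows that the Monte Carlo estimate of this term is the negative of PyTorch's `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, n_categories = 8, 10

# Codes c ~ P(c) (uniform over categories) and stand-in Q logits for the images G(z, c)
c_idx = torch.randint(0, n_categories, (n_samples,))
q_logits = torch.randn(n_samples, n_categories)

# Average log-probability that Q assigns to the code that produced each image
log_q = F.log_softmax(q_logits, dim=1).gather(1, c_idx[:, None]).squeeze(1).mean()

# The same quantity is the negative of the usual (mean-reduced) cross-entropy loss
ce = F.cross_entropy(q_logits, c_idx)
print(torch.allclose(log_q, -ce))  # True
```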

Updating the Minimax Game

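A vanilla generator and discriminator play a minimax game: $\displaystyle \min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log (1 - D(G(z)))]$, where $p_{\text{data}}$ is the data distribution and $p_z$ is the noise distribution.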

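To encourage mutual information, the game is updated so that $G$ and $Q$ also maximize the weighted mutual information: $\displaystyle \min_{G, Q} \max_{D} V(D, G) - \lambda I(c; G(z, c))$. Because $I(c; G(z, c))$ can't be computed directly, in practice it is replaced by the variational lower bound derived above.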
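In code, the two terms of $V(D, G)$ correspond to `nn.BCEWithLogitsLoss` with targets 1 and 0, which is how `adv_criterion` is used in the training loop. Here is a small check with arbitrary, hypothetical logits:

```python
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()
logit = torch.tensor([1.5, -0.3, 0.0])  # arbitrary discriminator logits, D(.) = sigmoid(logit)

# Target 1: BCE = -log D(x), the negated first term of V(D, G)
loss_real = bce(logit, torch.ones_like(logit))
print(torch.allclose(loss_real, -torch.log(torch.sigmoid(logit)).mean()))  # True

# Target 0: BCE = -log(1 - D(G(z))), the negated second term of V(D, G)
loss_fake = bce(logit, torch.zeros_like(logit))
print(torch.allclose(loss_fake, -torch.log(1 - torch.sigmoid(logit)).mean()))  # True
```

$D$ maximizes $V$, so it minimizes these BCE terms. $G$ and $Q$ maximize the mutual information bound, so the bound enters their minimized losses with a minus sign; this is the `- c_lambda * mutual_information` term in the training loop. The training loop's generator loss uses target 1 on fake images. It therefore minimizes $-\log D(G(z))$, the common non-saturating variant, rather than $\log(1 - D(G(z)))$.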

Implementing InfoGAN

For this notebook, you'll be using the MNIST dataset again.

You will begin by importing the necessary libraries and building the generator and discriminator. The generator is the same as before; its input is simply the noise and the latent code concatenated. The discriminator gains an additional output head that predicts the latent code.

Packages and Visualization

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()

Generator and Noise

In [2]:
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        input_dim: the dimension of the input vector, a scalar
        im_chan: the number of channels in the images, fitted for the dataset used, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN;
        a transposed convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor, 
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, input_dim)
        '''
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, input_dim, device='cpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, input_dim)
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        input_dim: the dimension of the input vector, a scalar
        device: the device type
    '''
    return torch.randn(n_samples, input_dim, device=device)
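The generator reshapes its input to a 1x1 feature map with `input_dim` channels, and each transposed convolution enlarges it. With no padding, the output-size formula for `nn.ConvTranspose2d` is `(size - 1) * stride + kernel_size`. The sketch below uses it to trace the four blocks above; the helper name is hypothetical:

```python
def conv_transpose_out(size, kernel_size, stride):
    """Output size of nn.ConvTranspose2d with padding=0, output_padding=0, dilation=1."""
    return (size - 1) * stride + kernel_size

# (kernel_size, stride) of the four generator blocks, starting from a 1x1 input
size = 1
sizes = [size]
for k, s in [(3, 2), (4, 1), (3, 2), (4, 2)]:
    size = conv_transpose_out(size, k, s)
    sizes.append(size)
print(sizes)  # [1, 3, 6, 13, 28]
```

The final 28x28 size is why the default kernel sizes and strides in `Generator.__init__` fit MNIST.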

InfoGAN Discriminator

You add a second head to predict a distribution for $c$ given $x$, alongside the traditional real/fake output. This notebook assumes a normal distribution for $c$, so this $Q$ head outputs a mean and a log-variance prediction for each latent code dimension ($2 \cdot c\_dim$ values in total).

In [3]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
      im_chan: the number of channels in the images, fitted for the dataset used, a scalar
            (MNIST is black-and-white, so 1 channel is your default)
      hidden_dim: the inner dimension, a scalar
      c_dim: the number of latent code dimensions; the Q head outputs a mean and a log-variance for each (2 * c_dim values)
    '''
    def __init__(self, im_chan=1, hidden_dim=64, c_dim=10):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
        )
        self.d_layer = self.make_disc_block(hidden_dim * 2, 1, final_layer=True)
        self.q_layer = nn.Sequential(
            self.make_disc_block(hidden_dim * 2, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 2 * c_dim, kernel_size=1, final_layer=True)
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of the DCGAN; 
        a convolution, a batchnorm (except in the final layer), and an activation (except in the final layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
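        Function for completing a forward pass of the discriminator: Given an image tensor,
        returns the real/fake logits with shape (n_samples, 1) and the Q-head output with
        shape (n_samples, 2 * c_dim) (latent code means followed by log-variances).
        Parameters:
            image: an image tensor with dimensions (n_samples, im_chan, 28, 28) for MNIST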
        '''
        intermediate_pred = self.disc(image)
        disc_pred = self.d_layer(intermediate_pred)
        q_pred = self.q_layer(intermediate_pred)
        return disc_pred.view(len(disc_pred), -1), q_pred.view(len(q_pred), -1)
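A similar sketch works for the discriminator. With no padding, the output-size formula for `nn.Conv2d` is `(size - kernel_size) // stride + 1`; the helper name is again hypothetical. Tracing it shows why both heads end at a 1x1 feature map:

```python
def conv_out(size, kernel_size, stride):
    """Output size of nn.Conv2d with padding=0, dilation=1."""
    return (size - kernel_size) // stride + 1

# Shared trunk: two 4x4, stride-2 blocks applied to a 28x28 MNIST image
trunk = conv_out(conv_out(28, 4, 2), 4, 2)
# Adversarial head (d_layer): one more 4x4, stride-2 convolution
d_size = conv_out(trunk, 4, 2)
# Q head (q_layer): a 4x4, stride-2 block followed by a 1x1, stride-2 convolution
q_size = conv_out(conv_out(trunk, 4, 2), 1, 2)
print(trunk, d_size, q_size)  # 5 1 1
```

Both heads end at 1x1, so `forward` flattens them to shapes `(n, 1)` and `(n, 2 * c_dim)`. The training loop then reads the first `c_dim` columns of the $Q$ output as means and the rest as log-variances.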

Helper Functions

You can reuse a helper function from the conditional GAN notebooks:

In [4]:
def combine_vectors(x, y):
    '''
    Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?).
    Parameters:
      x: (n_samples, ?) the first vector. 
        This will be the noise vector of shape (n_samples, z_dim).
      y: (n_samples, ?) the second vector.
        In this notebook, this will be the InfoGAN latent code
        with the shape (n_samples, c_dim).
    '''
    combined = torch.cat([x.float(), y.float()], 1)
    return combined

Training

Let's use the same parameters as in previous assignments. New ones are c_dim, the dimensionality of the InfoGAN latent code; c_criterion, the mutual information term; and its weight, c_lambda:

  • mnist_shape: the shape of each MNIST image, which is 28 x 28 pixels with one channel (because it's black-and-white), so 1 x 28 x 28
  • adv_criterion: the vanilla GAN loss function
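  • c_criterion: the mutual information term, computed as the mean log-likelihood of the true latent code under the Gaussian predicted by $Q$ (the lower bound up to the constant $H(c)$)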
  • c_lambda: the weight on the c_criterion
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • c_dim: the dimension of the InfoGAN latent code
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • d_lr, g_lr: the learning rates for the discriminator and the generator
  • device: the device type
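`c_criterion` is the mean Gaussian log-likelihood $\log Q(c|x)$ of the code that generated each image. Here is a sketch that checks the same expression against the closed-form normal density. The tensors are random stand-ins, not outputs of a trained model:

```python
import math
import torch
from torch.distributions.normal import Normal

torch.manual_seed(0)
c_true = torch.randn(4, 2)   # codes that generated the (hypothetical) images
mean = torch.randn(4, 2)     # stand-in Q-head means
logvar = torch.randn(4, 2)   # stand-in second half of the Q-head output

# Same expression as c_criterion (before averaging); Normal's second argument is a std (scale)
scale = logvar.exp()
log_q = Normal(mean, scale).log_prob(c_true)

# Closed-form log-density: -log(sigma) - 0.5*log(2*pi) - (c - mu)^2 / (2*sigma^2)
manual = -torch.log(scale) - 0.5 * math.log(2 * math.pi) - (c_true - mean) ** 2 / (2 * scale ** 2)
print(torch.allclose(log_q, manual))  # True
```

`torch.distributions.Normal` takes a standard deviation, not a variance. Passing `logvar.exp()` therefore means the head the notebook calls `logvar` effectively learns a log standard deviation. Training works either way. If you want the name to be literal, you could pass `(0.5 * logvar).exp()` instead.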
In [5]:
from torch.distributions.normal import Normal
adv_criterion = nn.BCEWithLogitsLoss()
# Mean log-likelihood of the true code under Q's Gaussian (Normal takes a std, not a variance)
c_criterion = lambda c_true, mean, logvar: Normal(mean, logvar.exp()).log_prob(c_true).mean()
c_lambda = 0.1
mnist_shape = (1, 28, 28)
n_epochs = 80
z_dim = 64
c_dim = 2
display_step = 500
batch_size = 128
# InfoGAN uses two different learning rates for the models
d_lr = 2e-4
g_lr = 1e-3
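# Fall back to the CPU if no GPU is available (training will be much slower)
device = 'cuda' if torch.cuda.is_available() else 'cpu'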

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

You initialize your networks as usual - notice that there is no separate $Q$ network. There are a few "design" choices worth noting here:

  1. There are many possible choices for the distribution over the latent code. This notebook uses a Gaussian prior, but a categorical (discrete) prior is also possible, and you can even use both together. In that case, you can also use a different weight $\lambda$ for each part of the code.
  2. You could calculate the mutual information bound explicitly, including $H(c)$, which is treated as a constant here. This notebook skips that step because you're not comparing the mutual information of different parameterizations of the latent code.
  3. There are multiple ways to handle the $Q$ network - this code follows the original paper by treating it as part of the discriminator, sharing most weights, but it is also possible to simply initialize another network.
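Here is a hedged sketch of design choices 1 and 2. The hypothetical helper `sample_mixed_code` samples one 10-way categorical code plus continuous codes; the continuous codes are standard normal here, to match this notebook. The hypothetical `mixed_info_term` computes a weighted bound with separate, illustrative λ values for each part. The sketch also computes the constant $H(c)$ of a standard normal code, which is $\frac{1}{2}\ln(2\pi e)$ nats per dimension:

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions.normal import Normal

def sample_mixed_code(n_samples, n_categories=10, n_continuous=2):
    """Hypothetical helper: one one-hot categorical code plus continuous codes ~ N(0, 1)."""
    cat_idx = torch.randint(0, n_categories, (n_samples,))
    cat_code = F.one_hot(cat_idx, n_categories).float()
    cont_code = torch.randn(n_samples, n_continuous)
    return cat_idx, torch.cat([cat_code, cont_code], dim=1)

def mixed_info_term(q_out, cat_idx, cont_code, n_categories=10, lambda_cat=1.0, lambda_cont=0.1):
    """Hypothetical weighted bound; q_out holds categorical logits, Gaussian means, then log-scales."""
    n_cont = cont_code.shape[1]
    logits = q_out[:, :n_categories]
    mean = q_out[:, n_categories:n_categories + n_cont]
    log_scale = q_out[:, n_categories + n_cont:]
    log_q_cat = -F.cross_entropy(logits, cat_idx)  # mean log Q(c_cat | x)
    log_q_cont = Normal(mean, log_scale.exp()).log_prob(cont_code).mean()  # mean log Q(c_cont | x)
    return lambda_cat * log_q_cat + lambda_cont * log_q_cont

# H(c) for a standard normal code is a constant: 0.5 * ln(2*pi*e) nats per dimension
h_per_dim = 0.5 * math.log(2 * math.pi * math.e)
print(round(h_per_dim, 4))  # 1.4189
```

With this layout, the $Q$ head would need `n_categories + 2 * n_continuous` outputs. For example, `torch.randn(8, 14)` could stand in for such a head with 10 categories and 2 continuous codes, with `code[:, 10:]` passed as `cont_code`.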
In [6]:
gen = Generator(input_dim=z_dim + c_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=g_lr)
disc = Discriminator(im_chan=mnist_shape[0], c_dim=c_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=d_lr)

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

Now let's get to training the networks:

In [7]:
cur_step = 0
generator_losses = []
discriminator_losses = []

for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        # Move the batch of real images to the device
        real = real.to(device)

        # Sample the latent code c from a standard normal prior
        c_labels = get_noise(cur_batch_size, c_dim, device=device)
        ### Update discriminator ###
        # Zero out the discriminator gradients
        disc_opt.zero_grad()
        # Get noise corresponding to the current batch_size
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        # Combine the noise vectors and the latent code for the generator
        noise_and_labels = combine_vectors(fake_noise, c_labels)
        # Generate the conditioned fake images
        fake = gen(noise_and_labels)
        
        # Get the discriminator's predictions; the Q head shares the discriminator's trunk
        disc_fake_pred, disc_q_pred = disc(fake.detach())
        disc_q_mean = disc_q_pred[:, :c_dim]
        disc_q_logvar = disc_q_pred[:, c_dim:]
        mutual_information = c_criterion(c_labels, disc_q_mean, disc_q_logvar)
        disc_real_pred, _ = disc(real)
        disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        # Subtracting the weighted bound trains Q (inside the discriminator) to maximize it
        disc_loss = (disc_fake_loss + disc_real_loss) / 2 - c_lambda * mutual_information
        # The generator step runs a fresh forward pass, so this graph need not be retained
        disc_loss.backward()
        disc_opt.step()

        # Keep track of the average discriminator loss
        discriminator_losses += [disc_loss.item()]

        ### Update generator ###
        # Zero out the generator gradients
        gen_opt.zero_grad()

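        # Re-run the discriminator and Q head on the non-detached images so gradients reach the generator
        disc_fake_pred, disc_q_pred = disc(fake)
        disc_q_mean = disc_q_pred[:, :c_dim]
        disc_q_logvar = disc_q_pred[:, c_dim:]
        mutual_information = c_criterion(c_labels, disc_q_mean, disc_q_logvar)
        # Non-saturating adversarial loss (target 1 on fakes) minus the weighted mutual information bound
        gen_loss = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) - c_lambda * mutual_information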
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the generator losses
        generator_losses += [gen_loss.item()]

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator Loss"
            )
            plt.legend()
            plt.show()
        cur_step += 1

Epoch 1, step 500: Generator loss: 0.8182289716005325, discriminator loss: 0.6970322011113167

Epoch 2, step 1000: Generator loss: 1.7924996489286422, discriminator loss: 0.26028732980787755

Epoch 3, step 1500: Generator loss: 2.6127100493907927, discriminator loss: 0.12645398038998246

Epoch 4, step 2000: Generator loss: 2.5406293122768404, discriminator loss: 0.22449365544319153

Epoch 5, step 2500: Generator loss: 2.415271957397461, discriminator loss: 0.26998663461208344

Epoch 6, step 3000: Generator loss: 2.324103059053421, discriminator loss: 0.27589657099545

Epoch 7, step 3500: Generator loss: 2.33047908949852, discriminator loss: 0.2694139747768641

Epoch 8, step 4000: Generator loss: 2.332132910490036, discriminator loss: 0.28993756261467934

Epoch 9, step 4500: Generator loss: 2.3577181141376498, discriminator loss: 0.2747534710764885

Epoch 10, step 5000: Generator loss: 2.339188297748566, discriminator loss: 0.28916628383100035

Epoch 11, step 5500: Generator loss: 2.1097206530570984, discriminator loss: 0.35928141298890115

Epoch 12, step 6000: Generator loss: 1.8833935358524323, discriminator loss: 0.36148455685377123

Epoch 13, step 6500: Generator loss: 1.8995823158025742, discriminator loss: 0.3979255060851574

Epoch 14, step 7000: Generator loss: 1.6693327586650848, discriminator loss: 0.39940346783399583

Epoch 15, step 7500: Generator loss: 1.6503913009166717, discriminator loss: 0.4311146695613861


Epoch 17, step 8000: Generator loss: 1.6739469344615936, discriminator loss: 0.4301695189476013

Epoch 18, step 8500: Generator loss: 1.634433900117874, discriminator loss: 0.4555605404376984

Epoch 19, step 9000: Generator loss: 1.6495194882154465, discriminator loss: 0.4601502610146999

Epoch 20, step 9500: Generator loss: 1.4434897894859313, discriminator loss: 0.49609006094932556

Epoch 21, step 10000: Generator loss: 1.4628067816495895, discriminator loss: 0.497731826364994

Epoch 22, step 10500: Generator loss: 1.457012984752655, discriminator loss: 0.5092470208704472

Epoch 23, step 11000: Generator loss: 1.4228946658372879, discriminator loss: 0.5129678817987442

Epoch 24, step 11500: Generator loss: 1.425648075580597, discriminator loss: 0.5259494391679764

Epoch 25, step 12000: Generator loss: 1.3924782351255416, discriminator loss: 0.5162326992154121

Epoch 26, step 12500: Generator loss: 1.40210223031044, discriminator loss: 0.5240750223994255

Epoch 27, step 13000: Generator loss: 1.3496780635118484, discriminator loss: 0.5310935707688331

Epoch 28, step 13500: Generator loss: 1.3106019917726517, discriminator loss: 0.5224199027419091

Epoch 29, step 14000: Generator loss: 1.3660474253892898, discriminator loss: 0.5262772083282471

Epoch 30, step 14500: Generator loss: 1.3374567056894302, discriminator loss: 0.531988255918026

Epoch 31, step 15000: Generator loss: 1.4002158173322679, discriminator loss: 0.5343616697788238


Epoch 33, step 15500: Generator loss: 1.3050618778467178, discriminator loss: 0.5232404959201813

Epoch 34, step 16000: Generator loss: 1.3428079401254653, discriminator loss: 0.524510620355606

Epoch 35, step 16500: Generator loss: 1.3471796796917916, discriminator loss: 0.5314510517716408

Epoch 36, step 17000: Generator loss: 1.3235193297863006, discriminator loss: 0.5291766377687455

Epoch 37, step 17500: Generator loss: 1.2901832588911057, discriminator loss: 0.5231917806267739

Epoch 38, step 18000: Generator loss: 1.402871264219284, discriminator loss: 0.5196232446432114

Epoch 39, step 18500: Generator loss: 1.3553507169485093, discriminator loss: 0.5339593841433525

Epoch 40, step 19000: Generator loss: 1.2890181401968002, discriminator loss: 0.5233975170254708

Epoch 41, step 19500: Generator loss: 1.2960289684534072, discriminator loss: 0.5264109694361687

Epoch 42, step 20000: Generator loss: 1.3218628760576248, discriminator loss: 0.5243060600757599

Epoch 43, step 20500: Generator loss: 1.350328610420227, discriminator loss: 0.5263542796969414

Epoch 44, step 21000: Generator loss: 1.3095761449337007, discriminator loss: 0.5240265983343124

Epoch 45, step 21500: Generator loss: 1.3215627145767213, discriminator loss: 0.5191929469704628

Epoch 46, step 22000: Generator loss: 1.327082029223442, discriminator loss: 0.5137711240053177

Epoch 47, step 22500: Generator loss: 1.298170992732048, discriminator loss: 0.5127231276631355


Epoch 49, step 23000: Generator loss: 1.323118221282959, discriminator loss: 0.5130633860826492

Epoch 50, step 23500: Generator loss: 1.336889844417572, discriminator loss: 0.5096884853243828

Epoch 51, step 24000: Generator loss: 1.3184827300310136, discriminator loss: 0.5030064074993134

Epoch 52, step 24500: Generator loss: 1.3355253584384918, discriminator loss: 0.5012828984856605

Epoch 53, step 25000: Generator loss: 1.3238797607421875, discriminator loss: 0.49775288832187653

Epoch 54, step 25500: Generator loss: 1.2969990675449372, discriminator loss: 0.5070032252669334

Epoch 55, step 26000: Generator loss: 1.2986541981697082, discriminator loss: 0.5025589602589607

Epoch 56, step 26500: Generator loss: 1.2987225022315978, discriminator loss: 0.5045067887306214

Epoch 57, step 27000: Generator loss: 1.3197853413820266, discriminator loss: 0.5004841610193252

Epoch 58, step 27500: Generator loss: 1.309844874739647, discriminator loss: 0.5035973317623138

Epoch 59, step 28000: Generator loss: 1.3151807837486267, discriminator loss: 0.499425533592701

Epoch 60, step 28500: Generator loss: 1.3056570110321044, discriminator loss: 0.5076552347540856

Epoch 61, step 29000: Generator loss: 1.312231298327446, discriminator loss: 0.5021820602416992

Epoch 62, step 29500: Generator loss: 1.329822805762291, discriminator loss: 0.511491660118103

Epoch 63, step 30000: Generator loss: 1.2953895548582077, discriminator loss: 0.5159559162855148


Epoch 65, step 30500: Generator loss: 1.2969535572528839, discriminator loss: 0.5165752695798874

Epoch 66, step 31000: Generator loss: 1.3206324875354767, discriminator loss: 0.5129701662659645

Epoch 67, step 31500: Generator loss: 1.2835827662944794, discriminator loss: 0.5127375834584236

Epoch 68, step 32000: Generator loss: 1.3052190783023834, discriminator loss: 0.5161510946154595

Epoch 69, step 32500: Generator loss: 1.2985972447395324, discriminator loss: 0.5168027797341347

Epoch 70, step 33000: Generator loss: 1.2808287208080291, discriminator loss: 0.5187389892339707

Epoch 71, step 33500: Generator loss: 1.2502468209266662, discriminator loss: 0.512817432820797

Epoch 72, step 34000: Generator loss: 1.2812227767705917, discriminator loss: 0.5022463051080703

Epoch 73, step 34500: Generator loss: 1.3140319393873214, discriminator loss: 0.5133160338401794

Epoch 74, step 35000: Generator loss: 1.296096970438957, discriminator loss: 0.5081236431002617

Epoch 75, step 35500: Generator loss: 1.2747619363069533, discriminator loss: 0.5088420104980469

Epoch 76, step 36000: Generator loss: 1.282422392487526, discriminator loss: 0.5058259936571121

Epoch 77, step 36500: Generator loss: 1.2950030229091645, discriminator loss: 0.5158454319238662

Epoch 78, step 37000: Generator loss: 1.2952934156656266, discriminator loss: 0.5081967522501946

Epoch 79, step 37500: Generator loss: 1.2680251882076263, discriminator loss: 0.5065745185613633

Exploration

You can do a bit of exploration now!

In [8]:
# Before you explore, you should put the generator
# in eval mode, both in general and so that batch norm
# doesn't cause you issues and is using its eval statistics
gen = gen.eval()

Changing the Latent Code Vector

You can generate some numbers with your new model! You can add interpolation as well to make it more interesting.

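Starting from one image, you will produce intermediate images that look more and more like the ending image until you reach the final image; you're basically morphing one image into another. Here, the start and end are chosen through the latent code. The first code dimension moves from -2 to 2 while the other code dimensions and the noise stay fixed, so each column shows what that single dimension controls for one noise vector.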
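The interpolation itself is just a weighted average, $(1 - t)\,c_{\text{start}} + t\,c_{\text{end}}$, for evenly spaced $t \in [0, 1]$. Here is a tiny sketch with hypothetical two-dimensional start and end codes that differ only in dimension 0:

```python
import torch

n_interpolation = 5
# Hypothetical start and end codes that differ only in dimension 0 (dimension 1 held fixed)
start = torch.tensor([-2.0, 0.3])
end = torch.tensor([2.0, 0.3])

# Linear interpolation: (1 - t) * start + t * end, with t evenly spaced in [0, 1]
t = torch.linspace(0, 1, n_interpolation)[:, None]
path = start * (1 - t) + end * t
print(path[:, 0].tolist())  # [-2.0, -1.0, 0.0, 1.0, 2.0]
```

The notebook's cell does the same thing with an extra leading axis, so that every noise vector in a row gets the same interpolated code.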

In [9]:
### Change me! ###
n_interpolation = 9 # Choose the interpolation: how many intermediate images you want + 2 (for the start and end image)

def interpolate_class(n_view=5):
    # One noise vector per column, repeated for every interpolation step (row)
    interpolation_noise = get_noise(n_view, z_dim, device=device).repeat(n_interpolation, 1)
    # A shared random code; only dimension 0 differs between the start (-2) and end (2) codes
    first_label = get_noise(1, c_dim).repeat(n_view, 1)[None, :]
    second_label = first_label.clone()
    first_label[:, :, 0] = -2
    second_label[:, :, 0] = 2

    # Calculate the interpolation vector between the two labels
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label
    interpolation_labels = interpolation_labels.view(-1, c_dim)

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation * n_view, nrow=n_view, show=False)

plt.figure(figsize=(8, 8))
interpolate_class()
_ = plt.axis('off')

You can also visualize the joint effect of both latent code dimensions for a single fixed noise vector by sweeping each dimension from -2 to 2 on a grid.
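The grid comes from combining `repeat`, which tiles the whole sequence, with `repeat_interleave`, which repeats each element in place. Here is a small sketch with 3 values per dimension that shows the resulting code pairs:

```python
import torch

n = 3
values = torch.linspace(-2, 2, n)     # [-2, 0, 2]
first = values.repeat(n)              # varies fastest: -2, 0, 2, -2, 0, 2, ...
second = values.repeat_interleave(n)  # varies slowest: -2, -2, -2, 0, 0, 0, ...
grid = torch.stack([first, second], dim=1)
print(grid.tolist())
# [[-2.0, -2.0], [0.0, -2.0], [2.0, -2.0], [-2.0, 0.0], [0.0, 0.0], [2.0, 0.0], [-2.0, 2.0], [0.0, 2.0], [2.0, 2.0]]
```

The first code dimension changes fastest. With `nrow` equal to the number of values, it therefore varies across each row of the image grid, while the second dimension varies from row to row.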

In [10]:
### Change me! ###
n_interpolation = 8 # Choose the grid size: how many values to try along each latent code dimension

def interpolate_class():
    # A single noise vector, shared by every image in the grid
    interpolation_noise = get_noise(1, z_dim, device=device).repeat(n_interpolation * n_interpolation, 1)

    # Build a grid over both latent code dimensions (assumes c_dim == 2):
    # the first dimension varies across columns, the second varies from row to row
    first_label = torch.linspace(-2, 2, n_interpolation).repeat(n_interpolation)
    second_label = torch.linspace(-2, 2, n_interpolation).repeat_interleave(n_interpolation)
    interpolation_labels = torch.stack([first_label, second_label], dim=1)

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation * n_interpolation, nrow=n_interpolation, show=False)

plt.figure(figsize=(8, 8))
interpolate_class()
_ = plt.axis('off')